Machine learning
Set up the environment if this notebook is executed on Google Colab.
Make sure to change the runtime type to GPU (Runtime -> Change runtime type -> GPU).
Otherwise, rendering won't work in Google Colab.
# Environment bootstrap: detect whether we run inside Google Colab and, if so,
# install the required packages before importing them.
import os

try:
    import google.colab  # noqa: F401 -- imported only to detect the Colab runtime
    IN_COLAB = True
except ImportError:  # was a bare `except:`; narrow to the actual failure mode
    IN_COLAB = False

if IN_COLAB:
    os.system("pip install --quiet 'x_xy[all_muj] @ git+https://github.com/SimiPixel/x_xy_v2'")
    os.system("pip install --quiet mediapy")

import x_xy

# automatically detects colab or not
x_xy.utils.setup_colab_env()

from x_xy.subpkgs import ml, exp, benchmark, sys_composer, sim2real, omc
import mediapy
import jax.numpy as jnp
import tree_utils
import jax
def load_systems():
    """Load the experimental system and derive the variants used downstream.

    Returns:
        A 3-tuple ``(sys, sys_noimu, sys_render)``: the full system including
        IMU links, the same system with IMU links removed (the filter's
        graph), and a combined system containing both the measured chain and
        a recolored, ``hat_``-prefixed copy for rendering the prediction.
    """
    sys = exp.load_sys("S_04", morph_yaml_key="seg2", delete_after_morph=["seg5", "imu3"])
    sys_noimu, _ = sys_composer.make_sys_noimu(sys)

    def _recolor(system: x_xy.System, rgba):
        # Drop the geoms attached to the root link (index 0) and recolor the rest.
        root_idx = 0
        kept = [g.replace(color=rgba) for g in system.geoms if g.link_idx != root_idx]
        return system.replace(geoms=kept)

    # Render color used for the predicted motion.
    prediction_rgba = (78 / 255, 163 / 255, 243 / 255, 1.0)
    hat_sys = _recolor(sys_noimu, prediction_rgba).add_prefix_suffix("hat_")
    sys_render = sys_composer.inject_system(sys, hat_sys)
    return sys, sys_noimu, sys_render
def load_data_and_prediction(motion, sys, sys_noimu, params):
    """Load one recorded trial and compute input, ground truth, and prediction.

    Args:
        motion: name of the recorded trial (e.g. ``"thomas_fast"``).
        sys: full system including IMU links.
        sys_noimu: system with IMU links removed (the filter's graph).
        params: pretrained filter parameters.

    Returns:
        ``(xs_translated, X, y, yhat)`` -- the (visually offset) maximal
        coordinates, the network input, the ground-truth relative poses,
        and the filter's predicted relative poses.
    """
    exp_data = exp.load_data("S_04", motion)
    xml_str = exp.load_xml_str("S_04")
    xs = sim2real.xs_from_raw(sys, exp.link_name_pos_rot_data(exp_data, xml_str), qinv=True)

    # Shift segments and IMUs slightly along x so the measured and predicted
    # chains don't visually overlap in the final render.
    translations, rotations = sim2real.unzip_xs(sys, xs)
    seg_mask = jnp.array([sys.name_to_idx(name) for name in sys.link_names[1:] if name[:3] != "imu"])
    imu_mask = jnp.array([sys.name_to_idx(name) for name in sys.link_names[1:] if name[:3] == "imu"])
    translations = translations.replace(pos=translations.pos.at[:, seg_mask, 0].set(translations.pos[:, seg_mask, 0] - 0.03))
    translations = translations.replace(pos=translations.pos.at[:, imu_mask, 0].set(translations.pos[:, imu_mask, 0] + 0.03))
    if sys.link_parents[sys.name_to_idx("seg2")] != -1:
        # a little extra offset for seg2 when it is not connected to the worldbody
        seg_mask = jnp.array([sys.name_to_idx("seg2")])
        translations = translations.replace(pos=translations.pos.at[:, seg_mask, 0].set(translations.pos[:, seg_mask, 0] - 0.02))
    xs_translated = sim2real.zip_xs(sys, translations, rotations)

    # Network input: rigid-IMU measurements per segment. The magnetometer
    # channel is dropped; seg3 carries no IMU and gets all-zero input.
    X = {seg: {} for seg in ["seg2", "seg3", "seg4"]}
    for seg in X:
        imu_data = exp_data[seg]["imu_rigid"]
        imu_data.pop("mag")
        if seg == "seg3":
            imu_data = tree_utils.tree_zeros_like(imu_data)
        X[seg].update(imu_data)

    # Ground truth: relative poses on the IMU-free graph.
    y = x_xy.rel_pose(sys_noimu, xs, sys)

    # Renamed from `filter`, which shadowed the builtin of the same name.
    rnno_filter = ml.RNNOFilter(params=params)
    rnno_filter.init(sys_noimu, tree_utils.tree_slice(X, 0))
    yhat = tree_utils.tree_slice(rnno_filter.predict(tree_utils.add_batch_dim(X)), 0)
    return xs_translated, X, y, yhat
def render(sys, sys_noimu, sys_render, xs, yhat):
    """Render measured and predicted motion side by side.

    Builds maximal coordinates for the predicted motion from the filter's
    relative-pose estimates `yhat`, merges them with the measured `xs` under
    `hat_`-prefixed link names, and renders every 4th timestep of the
    combined `sys_render` system.

    Args:
        sys: full system including IMU links.
        sys_noimu: system with IMU links removed.
        sys_render: combined system with measured and `hat_`-prefixed links.
        xs: measured maximal coordinates for `sys`.
        yhat: dict of per-link predicted quaternions (child-to-parent).

    Returns:
        List of rendered frames (images).
    """
    xs_noimu = sim2real.match_xs(sys_noimu, xs, sys)

    # `yhat` are child-to-parent transforms, but we need parent-to-child
    # this dictionary has now all links that don't connect to worldbody
    # NOTE(review): `jax.tree_map` is deprecated in newer JAX releases in
    # favor of `jax.tree_util.tree_map` -- confirm the pinned JAX version.
    transform2hat_rot = jax.tree_map(lambda quat: x_xy.maths.quat_inv(quat), yhat)

    transform1, transform2 = sim2real.unzip_xs(sys_noimu, xs_noimu)

    # we add the missing links in transform2hat, links that connect to worldbody
    transform2hat = []
    for i, name in enumerate(sys_noimu.link_names):
        if name in transform2hat_rot:
            # predicted rotation available -> rotation-only transform
            transform2_name = x_xy.Transform.create(rot=transform2hat_rot[name])
        else:
            # no prediction (link connects to worldbody) -> reuse measured transform
            transform2_name = transform2.take(i, axis=1)
        transform2hat.append(transform2_name)

    # after transpose shape is (n_timesteps, n_links, ...)
    transform2hat = transform2hat[0].batch(*transform2hat[1:]).transpose((1, 0, 2))
    xshat = sim2real.zip_xs(sys_noimu, transform1, transform2hat)

    # swap time axis, and link axis
    xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))

    # create mapping from `name` -> Transform; predicted links get a "hat_" prefix
    xs_dict = dict(
        zip(
            ["hat_" + name for name in sys_noimu.link_names],
            [xshat[i] for i in range(sys_noimu.num_links())],
        )
    )
    xs_dict.update(
        dict(
            zip(
                sys.link_names,
                [xs[i] for i in range(sys.num_links())],
            )
        )
    )

    # reassemble transforms in the link order expected by `sys_render`
    xs_render = []
    for name in sys_render.link_names:
        xs_render.append(xs_dict[name])
    xs_render = xs_render[0].batch(*xs_render[1:])
    xs_render = xs_render.transpose((1, 0, 2))

    # NOTE(review): assumes `.shape()` returns the leading (time) axis length
    # as an int -- confirm against the x_xy/tree_utils API.
    N = xs_render.shape()
    # subsample: render only every 4th timestep
    xs_render = [xs_render[t] for t in range(0, N, 4)]
    frames = x_xy.render(sys_render, xs_render, width=640, height=480, camera="c",
        add_cameras={-1: '<camera name="c" mode="targetbody" target="3" pos=".5 -.5 1.25"/>',})
    return frames
# --- script entry: load pretrained model, run prediction, render the video ---
params = ml.load(pretrained="rr_rr_unknown")  # pretrained RNNO filter parameters
motion = "thomas_fast"  # name of the recorded trial to visualize
sys, sys_noimu, sys_render = load_systems()
xs, X, y, yhat = load_data_and_prediction(motion, sys, sys_noimu, params)
frames = render(sys, sys_noimu, sys_render, xs, yhat)
# NOTE(review): fps=25 suggests 100 Hz data subsampled by 4 in `render` -- confirm
mediapy.show_video(frames, fps=25.0)